{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
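    }
   },
   "source": [
    "# Master wandb for Visual Model Analysis in 30 Minutes\n",
    "\n",
    "wandb is an acronym for \"wo ai ni, da baby\", Chinese for \"I love you, big baby\".\n",
    "\n",
    "As the name suggests, it is the darling of every \"alchemist\" (slang for a deep learning practitioner) and their favorite model-training companion.\n"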
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
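    }
   },
   "source": [
    "Just kidding!\n",
    "\n",
    "wandb is short for Weights & Biases, a machine learning visualization and analysis tool similar to TensorBoard.\n",
    "\n",
    "Compared with TensorBoard, wandb has these main advantages:\n",
    "\n",
    "* Logs are uploaded to and stored in the cloud, so they are easy to share and hard to lose.\n",
    "\n",
    "* Versions of code, datasets, and models can be stored and managed, so experiments can be reproduced at any time. (wandb.Artifact)\n",
    "\n",
    "* Interactive tables support case analysis. (wandb.Table)\n",
    "\n",
    "* Hyperparameter tuning can be automated. (wandb.sweep)\n"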
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Video demo on Bilibili: https://www.bilibili.com/video/BV17A41167WX/\n",
    "\n",
    "Official documentation: https://docs.wandb.ai/\n",
    "\n",
    "![](https://tva1.sinaimg.cn/large/008vOhrAgy1hakba64ol9j31920mcjud.jpg)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "Overall, wandb currently has 4 core features:\n",
    "\n",
    "1. Experiment tracking (wandb.log)\n",
    "\n",
    "2. Version management (wandb.log_artifact, wandb.save)\n",
    "\n",
    "3. Case analysis (wandb.Table, wandb.Image)\n",
    "\n",
    "4. Hyperparameter tuning (wandb.sweep)\n",
    "\n",
    "This article covers the first 3; hyperparameter tuning is covered in a separate article.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 0. Sign Up for wandb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "To visualize model training with wandb, register an account at https://wandb.ai/\n",
    "\n",
    "and get your API key from your personal settings page."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "#import os\n",
    "#os.environ[\"WANDB_API_KEY\"] = \"xxxx\"\n",
    "\n",
    "import wandb\n",
    "wandb.login()\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 1. Experiment Tracking"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
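    }
   },
   "source": [
    "wandb provides experiment tracking similar to TensorBoard, mainly:\n",
    "\n",
    "* Recording model configuration and hyperparameters\n",
    "\n",
    "* Recording and visualizing metrics such as loss during training\n",
    "\n",
    "* Visualizing images (wandb.Image)\n",
    "\n",
    "* Other media (wandb.Video, wandb.Audio, wandb.Html, 3D point clouds, etc.)\n"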
   ]
  },
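  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The tracking calls listed above can be tried without uploading anything by running wandb in offline mode. The cell below is a minimal sketch under stated assumptions, not the training code of this tutorial: `fake_loss` and `log_demo_run` are hypothetical helpers invented for illustration, and the project name `wandb_demo_offline` is an assumption. `log_demo_run` is only defined, not called, so running the cell does not start a run; call `log_demo_run()` to write a short offline run to local disk."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def fake_loss(step, total_steps=20):\n",
    "    # Hypothetical stand-in for a real training loss: decays smoothly from 1.0 toward 0.\n",
    "    return math.exp(-3.0 * step / total_steps)\n",
    "\n",
    "def log_demo_run(total_steps=20):\n",
    "    # Imported here so that defining the function does not require wandb to be installed.\n",
    "    import wandb\n",
    "    # mode='offline' keeps all logs on local disk; nothing is uploaded.\n",
    "    run = wandb.init(project='wandb_demo_offline', mode='offline',\n",
    "                     config={'total_steps': total_steps})\n",
    "    for step in range(total_steps):\n",
    "        # Each wandb.log call records one point per key in the dictionary.\n",
    "        wandb.log({'step': step, 'loss': fake_loss(step, total_steps)})\n",
    "    run.finish()"
   ]
  },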
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "import os,PIL \n",
    "import numpy as np\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import torch \n",
    "from torch import nn \n",
    "import torchvision \n",
    "from torchvision import transforms\n",
    "import datetime\n",
    "import wandb \n",
    "from argparse import Namespace\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "config = Namespace(\n",
    "    project_name = 'wandb_demo',\n",
    "    \n",
    "    batch_size = 512,\n",
    "    \n",
    "    hidden_layer_width = 64,\n",
    "    dropout_p = 0.1,\n",
    "    \n",
    "    lr = 1e-4,\n",
    "    optim_type = 'Adam',\n",
    "    \n",
    "    epochs = 15,\n",
    "    ckpt_path = 'checkpoint.pt'\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "def create_dataloaders(config):\n",
    "    transform = transforms.Compose([transforms.ToTensor()])\n",
    "    ds_train = torchvision.datasets.MNIST(root=\"./mnist/\",train=True,download=True,transform=transform)\n",
    "    ds_val = torchvision.datasets.MNIST(root=\"./mnist/\",train=False,download=True,transform=transform)\n",
    "\n",
    "    # train on every 5th sample to keep the demo fast\n",
    "    ds_train_sub = torch.utils.data.Subset(ds_train, indices=range(0, len(ds_train), 5))\n",
    "    dl_train =  torch.utils.data.DataLoader(ds_train_sub, batch_size=config.batch_size, shuffle=True,\n",
    "                                            num_workers=2,drop_last=True)\n",
    "    # drop_last=False so that val_acc is computed over the whole validation set\n",
    "    dl_val =  torch.utils.data.DataLoader(ds_val, batch_size=config.batch_size, shuffle=False, \n",
    "                                          num_workers=2,drop_last=False)\n",
    "    return dl_train,dl_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "def create_net(config):\n",
    "    net = nn.Sequential()\n",
    "    net.add_module(\"conv1\",nn.Conv2d(in_channels=1,out_channels=config.hidden_layer_width,kernel_size = 3))\n",
    "    net.add_module(\"pool1\",nn.MaxPool2d(kernel_size = 2,stride = 2)) \n",
    "    net.add_module(\"conv2\",nn.Conv2d(in_channels=config.hidden_layer_width,\n",
    "                                     out_channels=config.hidden_layer_width,kernel_size = 5))\n",
    "    net.add_module(\"pool2\",nn.MaxPool2d(kernel_size = 2,stride = 2))\n",
    "    net.add_module(\"dropout\",nn.Dropout2d(p = config.dropout_p))\n",
    "    net.add_module(\"adaptive_pool\",nn.AdaptiveMaxPool2d((1,1)))\n",
    "    net.add_module(\"flatten\",nn.Flatten())\n",
    "    net.add_module(\"linear1\",nn.Linear(config.hidden_layer_width,config.hidden_layer_width))\n",
    "    net.add_module(\"relu\",nn.ReLU())\n",
    "    net.add_module(\"linear2\",nn.Linear(config.hidden_layer_width,10))\n",
    "    net.to(device)\n",
    "    return net "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "def train_epoch(model,dl_train,optimizer):\n",
    "    model.train()\n",
    "    for step, batch in enumerate(dl_train):\n",
    "        features,labels = batch\n",
    "        features,labels = features.to(device),labels.to(device)\n",
    "\n",
    "        preds = model(features)\n",
    "        loss = nn.CrossEntropyLoss()(preds,labels)\n",
    "        loss.backward()\n",
    "\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "def eval_epoch(model,dl_val):\n",
    "    model.eval()\n",
    "    accurate = 0\n",
    "    num_elems = 0\n",
    "    for batch in dl_val:\n",
    "        features,labels = batch\n",
    "        features,labels = features.to(device),labels.to(device)\n",
    "        with torch.no_grad():\n",
    "            preds = model(features)\n",
    "        predictions = preds.argmax(dim=-1)\n",
    "        accurate_preds =  (predictions==labels)\n",
    "        num_elems += accurate_preds.shape[0]\n",
    "        accurate += accurate_preds.long().sum()\n",
    "\n",
    "    val_acc = accurate.item() / num_elems\n",
    "    return val_acc\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     13
    ],
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "def train(config = config):\n",
    "    dl_train, dl_val = create_dataloaders(config)\n",
    "    model = create_net(config)\n",
    "    # look up the optimizer class by name, e.g. 'Adam' -> torch.optim.Adam\n",
    "    optimizer = getattr(torch.optim, config.optim_type)(params=model.parameters(), lr=config.lr)\n",
    "    #======================================================================\n",
    "    nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "    wandb.init(project=config.project_name, config = config.__dict__, name = nowtime, save_code=True)\n",
    "    model.run_id = wandb.run.id\n",
    "    #======================================================================\n",
    "    model.best_metric = -1.0\n",
    "    for epoch in range(1,config.epochs+1):\n",
    "        model = train_epoch(model,dl_train,optimizer)\n",
    "        val_acc = eval_epoch(model,dl_val)\n",
    "        if val_acc>model.best_metric:\n",
    "            model.best_metric = val_acc\n",
    "            torch.save(model.state_dict(),config.ckpt_path)   \n",
    "        nowtime = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')\n",
    "        print(f\"epoch【{epoch}】@{nowtime} --> val_acc= {100 * val_acc:.2f}%\")\n",
    "        #======================================================================\n",
    "        wandb.log({'epoch':epoch, 'val_acc': val_acc, 'best_val_acc':model.best_metric})\n",
    "        #======================================================================        \n",
    "    #======================================================================\n",
    "    wandb.finish()\n",
    "    #======================================================================\n",
    "    return model   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true,
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "model = train(config) # 3, 2, 1, ignition 🔥🔥"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 2. Version Management"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
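    }
   },
   "source": [
    "Besides sending experiment logs to the wandb cloud server for visual analysis,\n",
    "wandb can also save the dataset, code, and model associated with an experiment to the wandb server.\n",
    "\n",
    "This makes it much easier for us, or for others, to reproduce the results later.\n",
    "\n",
    "We can use wandb.log_artifact to save the important outputs of a run, such as the dataset, code, and model, and keep them under version control.\n",
    "\n",
    "**Note: in software development, an artifact is a deliverable produced by the work.**\n"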
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "First, we resume the run by its run_id so that we can keep logging to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true,
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "#resume the run \n",
    "import wandb \n",
    "\n",
    "run = wandb.init(project=config.project_name, id=model.run_id, resume='must')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# save dataset \n",
    "arti_dataset = wandb.Artifact('mnist', type='dataset')\n",
    "arti_dataset.add_dir('mnist/')\n",
    "wandb.log_artifact(arti_dataset)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# save code \n",
    "\n",
    "arti_code = wandb.Artifact('ipynb', type='code')\n",
    "arti_code.add_file('./30分钟吃掉wandb可视化模型分析.ipynb')\n",
    "wandb.log_artifact(arti_code)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# save model\n",
    "\n",
    "arti_model = wandb.Artifact('cnn', type='model')\n",
    "arti_model.add_file(config.ckpt_path)\n",
    "wandb.log_artifact(arti_model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "wandb.finish() # the logged artifacts are committed and saved when the run finishes"
   ]
  },
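  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To reproduce a result later, a logged artifact can be pulled back from the server. The cell below is a minimal sketch, assuming the `cnn` model artifact logged above exists in the `wandb_demo` project and that the checkpoint inside it keeps the file name `checkpoint.pt`. `artifact_ref` and `download_model_checkpoint` are hypothetical helper names, and the `job_type` value is an arbitrary label; `latest` is the alias that points at the newest version of an artifact. The function is only defined here; calling it starts a short run and returns a local path that could then be passed to `torch.load` and `model.load_state_dict`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "def artifact_ref(name, alias='latest'):\n",
    "    # Artifacts are addressed as '<name>:<alias>', for example 'cnn:latest' or 'cnn:v0'.\n",
    "    return f'{name}:{alias}'\n",
    "\n",
    "def download_model_checkpoint(project='wandb_demo', name='cnn', filename='checkpoint.pt'):\n",
    "    # Imported here so that defining the function does not require wandb to be installed.\n",
    "    import wandb\n",
    "    run = wandb.init(project=project, job_type='download')\n",
    "    # use_artifact records that this run consumed the artifact and returns a handle to it.\n",
    "    artifact = run.use_artifact(artifact_ref(name), type='model')\n",
    "    local_dir = artifact.download()  # local directory that holds the artifact files\n",
    "    run.finish()\n",
    "    return os.path.join(local_dir, filename)"
   ]
  },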
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "## 3. Case Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "source": [
    "With wandb.Table, we can do interactive, visual case analysis in the wandb dashboard."
   ]
  },
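  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before involving a model, the idea can be sketched on plain lists: split the samples into correct and incorrect predictions, cap each bucket, and turn the rows into a table. This is a minimal illustration, not the code used in this tutorial: `split_cases` and `to_table` are hypothetical helpers, the `Index` column is an assumption, and the labels and predictions are made-up toy values. `to_table` is only defined, not called, so the cell runs without wandb."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_cases(labels, preds, limit=50):\n",
    "    # Split [index, label, prediction] rows into correct and incorrect cases,\n",
    "    # keeping at most `limit` rows in each bucket.\n",
    "    good, bad = [], []\n",
    "    for i, (label, pred) in enumerate(zip(labels, preds)):\n",
    "        bucket = good if label == pred else bad\n",
    "        if len(bucket) < limit:\n",
    "            bucket.append([i, label, pred])\n",
    "    return good, bad\n",
    "\n",
    "def to_table(rows):\n",
    "    # Imported here so that defining the function does not require wandb to be installed.\n",
    "    import wandb\n",
    "    return wandb.Table(columns=['Index', 'GroundTruth', 'Prediction'], data=rows)\n",
    "\n",
    "# Toy labels and predictions; the last sample is correct but exceeds limit=2.\n",
    "good, bad = split_cases([7, 2, 1, 0], [7, 2, 7, 0], limit=2)\n",
    "print(good)  # [[0, 7, 7], [1, 2, 2]]\n",
    "print(bad)   # [[2, 1, 7]]"
   ]
  },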
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true,
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "#resume the run \n",
    "import wandb \n",
    "run = wandb.init(project=config.project_name, id= model.run_id, resume='must')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [
     8,
     13
    ],
    "scrolled": true,
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt \n",
    "\n",
    "transform = transforms.Compose([transforms.ToTensor()])\n",
    "ds_train = torchvision.datasets.MNIST(root=\"./mnist/\",train=True,download=True,transform=transform)\n",
    "ds_val = torchvision.datasets.MNIST(root=\"./mnist/\",train=False,download=True,transform=transform)\n",
    "    \n",
    "# visualize the predictions\n",
    "model.eval()  # make sure dropout is disabled for inference\n",
    "device = next(model.parameters()).device  # run inference on the model's device\n",
    "\n",
    "plt.figure(figsize=(8,8)) \n",
    "for i in range(9):\n",
    "    img,label = ds_val[i]\n",
    "    tensor = img.to(device)\n",
    "    y_pred = torch.argmax(model(tensor[None,...])) \n",
    "    img = img.permute(1,2,0)\n",
    "    ax=plt.subplot(3,3,i+1)\n",
    "    ax.imshow(img.numpy())\n",
    "    ax.set_title(\"y_pred = %d\"%y_pred)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([]) \n",
    "plt.show()    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
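   },
   "outputs": [],
   "source": [
    "def data2fig(data):\n",
    "    # draw an image array on a fresh matplotlib figure without axis ticks\n",
    "    import matplotlib.pyplot as plt \n",
    "    fig = plt.figure()\n",
    "    ax = fig.add_subplot()\n",
    "    ax.imshow(data)\n",
    "    ax.set_xticks([])\n",
    "    ax.set_yticks([]) \n",
    "    return fig\n",
    "\n",
    "def fig2img(fig):\n",
    "    # render a matplotlib figure into an in-memory PIL image\n",
    "    import io\n",
    "    import matplotlib.pyplot as plt \n",
    "    from PIL import Image\n",
    "    buf = io.BytesIO()\n",
    "    fig.savefig(buf)\n",
    "    plt.close(fig)  # release the figure so repeated calls do not pile up open figures\n",
    "    buf.seek(0)\n",
    "    img = Image.open(buf)\n",
    "    return img\n"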
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm \n",
    "good_cases = wandb.Table(columns = ['Image','GroundTruth','Prediction'])\n",
    "bad_cases = wandb.Table(columns = ['Image','GroundTruth','Prediction'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "code_folding": [],
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "# Collect up to 50 good cases and 50 bad cases from the first 1000 validation samples\n",
    "\n",
    "plt.close()\n",
    "\n",
    "for i in tqdm(range(1000)):\n",
    "    features,label = ds_val[i]\n",
    "    tensor = features.to(device)\n",
    "    with torch.no_grad():\n",
    "        # .item() turns the 0-d tensor into a plain int before it goes into the Table\n",
    "        y_pred = torch.argmax(model(tensor[None,...])).item()\n",
    "    \n",
    "    # log bad case\n",
    "    if y_pred!=label:\n",
    "        if len(bad_cases.data)<50:\n",
    "            data = features.permute(1,2,0).numpy()\n",
    "            input_img = wandb.Image(fig2img(data2fig(data)))\n",
    "            bad_cases.add_data(input_img,label,y_pred)\n",
    "            \n",
    "    # log good case\n",
    "    else:\n",
    "        if len(good_cases.data)<50:\n",
    "            data = features.permute(1,2,0).numpy()\n",
    "            input_img = wandb.Image(fig2img(data2fig(data)))\n",
    "            good_cases.add_data(input_img,label,y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [],
   "source": [
    "wandb.log({'good_cases':good_cases,'bad_cases':bad_cases})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "slideshow": {
     "slide_type": "-"
    }
   },
   "outputs": [],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "The Table can then be analyzed interactively in the dashboard,\n",
    "\n",
    "with common operations such as sort, filter, and group. It is effectively a web-based Excel."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "![算法美食屋logo.png](https://tva1.sinaimg.cn/large/e6c9d24egy1h41m2zugguj20k00b9q46.jpg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
